import jax.numpy as jnp
from flax.linen.initializers import constant, orthogonal

# dtype that every parameter produced by `orthogonal_wrapped` is cast to.
PRECISION = jnp.float32


# Because orthogonal initialization does not support float16, the wrapper
# below never forwards the caller's dtype to it: the matrix is generated in
# the initializer's own default dtype and only cast to PRECISION afterwards.
def orthogonal_wrapped(*init_args, **init_kwargs):
    """Build an orthogonal initializer that ignores the requested dtype.

    Args:
        *init_args: Positional arguments forwarded unchanged to
            ``orthogonal`` (e.g. the scale).
        **init_kwargs: Keyword arguments forwarded unchanged to
            ``orthogonal``.

    Returns:
        A callable ``init(key, shape, dtype=None, **kwargs)`` that returns the
        orthogonal parameters cast to ``PRECISION``.
    """
    base_init = orthogonal(*init_args, **init_kwargs)

    def init(key, shape, dtype=None, **kwargs):
        """Return orthogonal parameters of ``shape`` cast to ``PRECISION``.

        ``dtype`` is accepted positionally or by keyword, or may be omitted,
        and is deliberately discarded; see the comment above the factory.
        """
        # Drop dtype by name rather than slicing off the last positional
        # argument: slicing silently discarded `shape` whenever dtype was
        # omitted or passed as a keyword.
        del dtype
        params = base_init(key, shape, **kwargs)
        return params.astype(PRECISION)

    return init